SVM (support vector machine 支援向量機),是在特徵空間中找到一個分離超平面,也就是「決策邊界」(decision boundary)。
我們可以透過這個決策邊界(紅色虛線)將資料分成不同類別,最佳化的目標是「邊界」(margin)。
離灰線最近的是支援向量(support vector)。
當然不是每個情況都可以靠這種一刀兩斷的方式進行分類,SVM 可以將資料投影到高維度空間,在高維度空間找到超平面進行分割。
通常使用的 kernal 核心轉換有 RBF Kernal 高斯轉換、高次方轉換等等
資料來源
這次我想要預測顧客最後會不會購買書,可以使用的特徵因子有性別、年齡、薪水及是否為VIP註記。這邊搭配 PCA 做降維,選擇 n_components=2(最後才可以畫圖呀XDD)不做降維不好畫圖啊!
# 訓練線性 SVM 並預測結果
kernel = 'linear'
model = SVC(kernel=kernel)
model.fit(dx_train, dy_train)
predict = model.predict(dx_test)
test_score = model.score(dx_test, dy_test) * 100
plt.figure(figsize=(8, 8))
plt.rcParams['font.size'] = 14
plt.title(f'SVM {kernel} (accuracy={test_score:.1f}%)')
plt.scatter(*dx_test.T, c=predict, cmap='tab10', s=100)
plt.scatter(*dx_test.T, c=dy_test, cmap='Set3', s=35)
# 求出超平面與邊界
x_min = np.amin(dx_test.T[0])
x_max = np.amax(dx_test.T[0])
y_min = np.amin(dx_test.T[1])
y_max = np.amax(dx_test.T[1])
XX, YY = np.mgrid[x_min:x_max:200j, y_min:y_max:200j]
Z = model.decision_function(
np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)
# 畫出超平面與邊界
plt.contour(XX, YY, Z, colors=['grey', 'coral', 'grey'],
linestyles=['--', '-', '--'], linewidths=[2, 2, 2],
levels=[-1, 0, 1])
plt.grid(True)
plt.xlim([x_min, x_max])
plt.ylim([y_min, y_max])
plt.tight_layout()
plt.show()
在 kernel 還可以選擇 poly、sigmoid,最後發現 rbf 成效最好!
# 訓練非線性 SVM 並預測結果
kernel = 'rbf'
model = SVC(kernel=kernel)
model.fit(dx_train, dy_train)
predict = model.predict(dx_test)
test_score = model.score(dx_test, dy_test) * 100
plt.figure(figsize=(8, 8))
plt.rcParams['font.size'] = 14
plt.title(f'SVM {kernel} (accuracy={test_score:.1f}%)')
plt.scatter(*dx_test.T, c=predict, cmap='tab10', s=100)
plt.scatter(*dx_test.T, c=dy_test, cmap='Set3', s=35)
# 求出超平面與邊界
x_min = np.amin(dx_test.T[0])
x_max = np.amax(dx_test.T[0])
y_min = np.amin(dx_test.T[1])
y_max = np.amax(dx_test.T[1])
XX, YY = np.mgrid[x_min:x_max:200j, y_min:y_max:200j]
Z = model.decision_function(
np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)
# 畫出超平面與邊界
plt.contour(XX, YY, Z, colors=['grey', 'coral', 'grey'],
linestyles=['--', '-', '--'], linewidths=[2, 2, 2],
levels=[-1, 0, 1])
plt.grid(True)
plt.xlim([x_min, x_max])
plt.ylim([y_min, y_max])
plt.tight_layout()
plt.show()
更詳細可以請參考連結